# Copyright 2025 ZTE Corporation.
# All Rights Reserved.
#
#    Licensed under the Apache License, Version 2.0 (the "License"); you may
#    not use this file except in compliance with the License. You may obtain
#    a copy of the License at
#
#         http://www.apache.org/licenses/LICENSE-2.0
#
#    Unless required by applicable law or agreed to in writing, software
#    distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
#    WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
#    License for the specific language governing permissions and limitations
#    under the License.

import logging
import os
import platform
import warnings
from typing import Any, List, Optional, Type, Union

from langchain_core.callbacks import (
    CallbackManagerForToolRun,
)
from langchain_core.tools import BaseTool
from pydantic import BaseModel, Field, model_validator

logger = logging.getLogger(__name__)


class ShellInput(BaseModel):
    """Input schema for the shell tool.

    ``commands`` may be supplied as a single value or as a list; a
    pre-validation hook wraps a lone value in a one-element list before
    field validation runs.
    """

    commands: Union[str, List[str]] = Field(
        ...,
        description="List of shell commands to run. Deserialized using json.loads",
    )
    """List of shell commands to run."""

    @model_validator(mode="before")
    @classmethod
    def _validate_commands(cls, values: dict) -> Any:
        """Normalize ``commands`` to a list and emit the safety warning.

        Args:
            values: Raw input handed to the model before field validation.
                Normally a dict, but a "before" validator receives whatever
                the caller passed, so other types are tolerated.

        Returns:
            The input with ``commands`` wrapped in a list when it was
            present and not already a list; otherwise the input unchanged.
        """
        # TODO: Add real validators
        # Only touch the payload when it is a dict that actually carries the
        # key: a non-dict or a missing key is left for pydantic to report
        # with its own (clearer) validation error.
        if isinstance(values, dict) and "commands" in values:
            commands = values["commands"]
            if not isinstance(commands, list):
                # Build a new dict so the caller's object is not mutated.
                values = {**values, "commands": [commands]}
        # Warn that the bash tool is not safe
        warnings.warn(
            "The shell tool has no safeguards by default. Use at your own risk."
        )
        return values


def _get_default_bash_process() -> Any:
    """Get default bash process."""
    try:
        from langchain_experimental.llm_bash.bash import BashProcess
    except ImportError:
        raise ImportError(
            "BashProcess has been moved to langchain experimental."
            "To use this tool, install langchain-experimental "
            "with `pip install langchain-experimental`."
        )
    return BashProcess(return_err_output=True)


def _get_platform() -> str:
    """Get platform."""
    system = platform.system()
    if system == "Darwin":
        return "MacOS"
    return system


class ShellTool(BaseTool):  # type: ignore[override, override]
    """Tool to run shell commands.

    Two entry points are provided:

    * ``_run`` forwards whatever it is given to ``process.run``, optionally
      asking for interactive confirmation first (``ask_human_input``).
    * ``execute`` is a stricter wrapper that only forwards commands whose
      executable is on a small allowlist, and never prompts.
    """

    process: Any = Field(default_factory=_get_default_bash_process)
    """Bash process to run commands."""

    name: str = "terminal"
    """Name of tool."""

    description: str = f"Run shell commands on this {_get_platform()} machine."
    """Description of tool."""

    args_schema: Type[BaseModel] = ShellInput
    """Schema for input arguments."""

    ask_human_input: bool = False
    """
    If True, prompts the user for confirmation (y/n) before executing
    a command generated by the language model in the bash shell.
    """

    def _run(
        self,
        commands: Union[str, List[str]],
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str:
        """Run commands through ``process`` and return its output.

        Args:
            commands: A command string or list of command strings, passed
                unchanged to ``self.process.run``.
            run_manager: Accepted for interface compatibility; not used.

        Returns:
            The value returned by ``self.process.run(commands)``. ``None``
            is returned instead when the user answers anything other than
            "y" at the confirmation prompt, or when any exception is raised
            (the exception is logged with its traceback, not propagated).
        """

        print(f"Executing command:\n {commands}")  # noqa: T201

        try:
            if self.ask_human_input:
                user_input = input("Proceed with command execution? (y/n): ").lower()
                if user_input != "y":
                    logger.info("Invalid input. User aborted command execution.")
                    return None  # type: ignore[return-value]
            return self.process.run(commands)
        except Exception as e:
            # Deliberate best-effort boundary: failures are reported to the
            # caller as None rather than raised.
            logger.exception("Error during command execution: %s", e)
            return None  # type: ignore[return-value]

    def execute(self, command: str) -> str:
        r"""Execute an allowlisted shell command and return the output.

        The command is accepted only when its first whitespace-separated
        token is exactly one of the allowlisted executables. Accepted
        commands run without the interactive confirmation prompt; the
        instance's ``ask_human_input`` setting is restored afterwards.

        Args:
            command (str): The shell command to execute.

        Returns:
            str: The output of the command or an error message.
        """
        # Executables that may run without confirmation; anything else is
        # refused.
        allowed_executables = ("ls", "pip", "cat")

        # Compare the executable itself instead of searching for a substring:
        # a substring test accepts e.g. "rm -rf x # ls " and rejects a bare
        # "ls".
        tokens = command.split() if isinstance(command, str) else []
        if not tokens or tokens[0] not in allowed_executables:
            return "Error: Command blocked for security reasons"

        # NOTE(review): only the leading executable is inspected. Shell
        # operators later in the string (&&, |, ;, $(...)) are not analysed,
        # so this is not a complete sandbox.

        # Skip the prompt for this call only; restore the configured value
        # so later _run calls keep their confirmation behaviour.
        previous_setting = self.ask_human_input
        self.ask_human_input = False
        try:
            result = self._run(command)
        except Exception as e:
            return f"Error executing command: {str(e)}"
        finally:
            self.ask_human_input = previous_setting

        if result is None:
            return "Command execution was aborted or failed."
        return result



if __name__ == "__main__":
    # Manual smoke check: create a ShellTool instance, then print what
    # execute() returns for a command that is not on its allowlist and for
    # a deliberately invalid command name.
    demo_tool = ShellTool()
    for demo_command in ("playwright install", "invalid_command"):
        print(demo_tool.execute(demo_command))